#include "BinarySearchTree.hpp"
#include <stdlib.h>
#include <stdio.h>
#include <string.h>

// Allocates a detached leaf node whose `data` points at the caller's string
// (the string is borrowed, not copied, so it must outlive the node).
// Returns nullptr if the allocation fails instead of writing through it.
BinarySearchNode *buildNode(const char *data)
{
    BinarySearchNode *newNode = static_cast<BinarySearchNode *>(malloc(sizeof(BinarySearchNode)));
    if (newNode == nullptr)
        return nullptr;
    newNode->data = data;
    newNode->left = nullptr;
    newNode->right = nullptr;
    return newNode;
}

// Returns the node holding the largest key in the subtree rooted at `p`
// (which must be non-null).
// If that node lies strictly below `p`, it is unlinked by hooking its left
// child into its parent. If `p` itself is the maximum (no right child),
// `p` is returned untouched and the caller is responsible for unlinking it.
BinarySearchNode *findMaxNodeAndDelete(BinarySearchNode *p)
{
    if (p->right == nullptr)
        return p;

    BinarySearchNode *parent = p;
    BinarySearchNode *cursor = p->right;
    while (cursor->right != nullptr)
    {
        parent = cursor;
        cursor = cursor->right;
    }
    // cursor has no right child, so only its left subtree needs re-homing.
    parent->right = cursor->left;
    return cursor;
}

// Three-way comparison of two nodes by key: negative, zero or positive as
// p->data sorts before, equal to, or after p2->data (strcmp ordering).
int compre(BinarySearchNode *p, BinarySearchNode *p2)
{
    const char *lhs = p->data;
    const char *rhs = p2->data;
    return strcmp(lhs, rhs);
}

// Creates a tree holding a single node keyed by `data` (pointer is borrowed).
BinarySearchTree::BinarySearchTree(const char *data)
{
    BinarySearchNode *first = buildNode(data);
    root = first;
}

// Looks up `data` by strcmp ordering; returns the matching node or nullptr.
BinarySearchNode *BinarySearchTree::find(const char *data)
{
    BinarySearchNode *cursor = root;
    while (cursor != nullptr)
    {
        const int order = strcmp(cursor->data, data);
        if (order == 0)
            break;
        // Node key smaller than the target -> the target can only be to the right.
        cursor = (order < 0) ? cursor->right : cursor->left;
    }
    return cursor;
}

// Inserts `data` (pointer borrowed, not copied) at its strcmp-ordered position.
// Returns the node holding `data`: the freshly linked node, or the existing
// node if the key is already present (no duplicate is created and nothing
// leaks). Works on an empty tree by making the new node the root. Returns
// nullptr only if the node allocation fails.
BinarySearchNode *BinarySearchTree::insertNode(const char *data)
{
    // Walk the address of the child pointer to update, so the empty-tree case
    // and the left/right cases share one code path.
    BinarySearchNode **link = &root;
    while (*link != nullptr)
    {
        BinarySearchNode *p = *link;
        const int cmpr = strcmp(p->data, data);
        if (cmpr == 0)
            return p;
        link = (cmpr < 0) ? &p->right : &p->left;
    }
    // Allocate only once the insertion point is known.
    *link = buildNode(data);
    return *link;
}

// Returns the node with the smallest key, or nullptr for an empty tree.
// Delegates to the null-safe subtree overload (mirrors findMax()), so an empty
// root no longer causes a null dereference.
BinarySearchNode *BinarySearchTree::findMin()
{
    return findMin(root);
}

void BinarySearchTree::remove(const char *data)
{
    int cmpr;
    if ((cmpr = strcmp(root->data, data)) == 0)
    {
        if (root->left == nullptr)
        {
            if (root->right == nullptr)
            {
                return;
            }
            BinarySearchNode *tmpn = root->right;
            free(root);
            root = tmpn;
            return;
        }
        if (root->right == nullptr)
        {
            BinarySearchNode *tmpn = root->left;
            free(root);
            root = tmpn;
            return;
        }
        BinarySearchNode *node = findMaxNodeAndDelete(root->left);
        node->left = root->left;
        node->right = root->right;
        free(root);
        root = node;
        return;
    }
    BinarySearchNode *pp = root;
    BinarySearchNode *p = cmpr < 0 ? root->right : root->left;
    for (;; p = cmpr < 0 ? p->right : p->left)
    {
        if (p == nullptr)
        {
            break;
        }
        int right = cmpr < 0;
        cmpr = strcmp(p->data, data);
        if (cmpr == 0)
        {
            printf("remove node(%s)->node(%s %p%p)\n", pp->data, p->data, p->left, p->right);
            if (p->left == nullptr)
            {
                if (p->right == nullptr)
                {
                    free(p);
                    if (right)
                    {
                        pp->right = nullptr;
                    }
                    else
                    {
                        pp->left = nullptr;
                    }
                    break;
                }
                if (right)
                {
                    pp->right = p->left;
                }
                else
                {
                    pp->left = p->right;
                }
                free(p);
                break;
            }
            if (p->right == nullptr)
            {
                if (right)
                {
                    pp->right = p->left;
                }
                else
                {
                    pp->left = p->left;
                }
                free(p);
                break;
            }
            BinarySearchNode *node = findMaxNodeAndDelete(p->left);
            if (right)
            {
                pp->right = node;
            }
            else
            {
                pp->left = node;
            }
            node = p->right;
            node = p->left;
            free(p);
            break;
        }
        pp = p;
    }
}

// Returns the node with the largest key in the whole tree, or nullptr if the
// tree is empty (the subtree overload handles the null case).
BinarySearchNode *BinarySearchTree::findMax()
{
    BinarySearchNode *top = root;
    return findMax(top);
}

// In-order walk backing isBST(node, prev). `*prev` holds the key visited most
// recently (nullptr = none yet) and is advanced as nodes are visited, so the
// running predecessor propagates across the whole traversal.
static int isBSTInOrder(BinarySearchNode *node, const char **prev)
{
    if (node == nullptr)
        return 1;
    if (!isBSTInOrder(node->left, prev))
        return 0;
    // Keys must be strictly increasing in in-order sequence.
    if (*prev != nullptr && strcmp(*prev, node->data) >= 0)
        return 0;
    *prev = node->data;
    return isBSTInOrder(node->right, prev);
}

// Returns 1 if the in-order key sequence of `node` is strictly increasing and
// every key is greater than `prev`, else 0. Pass nullptr for `prev` to impose
// no lower bound.
// Previously `prev` was copied by value and never compared, so this always
// returned 1.
int BinarySearchTree::isBST(BinarySearchNode *node, const char *prev)
{
    return isBSTInOrder(node, &prev);
}
// Returns 1 if every key in `node`'s subtree lies strictly between `min` and
// `max` (strcmp order) and the subtree is ordered, else 0.
// A nullptr bound means "unbounded" on that side; previously it was passed
// straight to strcmp and crashed.
int BinarySearchTree::isBST(BinarySearchNode *node, const char *min, const char *max)
{
    if (node == nullptr)
        return 1;
    if (min != nullptr && strcmp(node->data, min) <= 0)
        return 0;
    if (max != nullptr && strcmp(node->data, max) >= 0)
        return 0;
    // This node's key becomes the upper bound on the left and the lower bound on the right.
    return isBST(node->left, min, node->data) && isBST(node->right, node->data, max);
}
// Returns 1 if the subtree rooted at `node` is a strict BST, else 0.
// At each node, the largest key on the left must be smaller than the node's
// key and the smallest key on the right must be larger. The previous version
// compared those extremes against the child instead of `node`, so misplaced
// keys went undetected. Recomputing extremes per node makes this O(n^2) on
// degenerate trees.
int BinarySearchTree::isBST(BinarySearchNode *node)
{
    if (node == nullptr)
        return 1;
    if (node->left != nullptr && compre(findMax(node->left), node) >= 0)
        return 0;
    if (node->right != nullptr && compre(findMin(node->right), node) <= 0)
        return 0;
    if (!isBST(node->left) || !isBST(node->right))
        return 0;
    return 1;
}

// Returns the leftmost (smallest-key) node of the subtree rooted at `node`,
// or nullptr when `node` is nullptr.
BinarySearchNode *BinarySearchTree::findMin(BinarySearchNode *node)
{
    BinarySearchNode *cur = node;
    while (cur != nullptr && cur->left != nullptr)
        cur = cur->left;
    return cur;
}
// Returns the rightmost (largest-key) node of the subtree rooted at `node`,
// or nullptr when `node` is nullptr.
BinarySearchNode *BinarySearchTree::findMax(BinarySearchNode *node)
{
    BinarySearchNode *cur = node;
    while (cur != nullptr && cur->right != nullptr)
        cur = cur->right;
    return cur;
}
// Recursively relinks the subtree at `node` into an in-order doubly linked
// list (left = previous, right = next). Returns the head and stores the tail
// in `*tail` (both nullptr for an empty subtree).
static BinarySearchNode *bstToDllImpl(BinarySearchNode *node, BinarySearchNode **tail)
{
    if (node == nullptr)
    {
        *tail = nullptr;
        return nullptr;
    }
    BinarySearchNode *leftTail = nullptr;
    BinarySearchNode *rightTail = nullptr;
    BinarySearchNode *leftHead = bstToDllImpl(node->left, &leftTail);
    BinarySearchNode *rightHead = bstToDllImpl(node->right, &rightTail);

    // Splice: [left list] <-> node <-> [right list]
    node->left = leftTail;
    if (leftTail != nullptr)
        leftTail->right = node;
    node->right = rightHead;
    if (rightHead != nullptr)
        rightHead->left = node;

    *tail = (rightHead != nullptr) ? rightTail : node;
    return (leftHead != nullptr) ? leftHead : node;
}

// Destructively converts the subtree at `root` into a sorted doubly linked
// list (left = prev, right = next) and returns its head.
// `Ltail` is passed by value, so it cannot report the tail back to the caller;
// it is ignored. The old code passed uninitialized locals as "out" tails and
// then read them, which is undefined behaviour.
BinarySearchNode *BinarySearchTree::BST2DLL(BinarySearchNode *root, BinarySearchNode *Ltail)
{
    (void)Ltail;
    BinarySearchNode *tail = nullptr;
    return bstToDllImpl(root, &tail);
}
// Placeholder: always returns nullptr and leaves the tree untouched.
// TODO(review): presumably meant to convert `root`. Doing so would destroy the
// tree structure that `root` and the other members rely on — confirm the
// intended semantics before implementing.
BinarySearchNode *BinarySearchTree::BST2DLL()
{
    BinarySearchNode *head = nullptr;
    return head;
}

// Returns 1 if the whole tree satisfies the BST ordering, else 0, using the
// per-node extreme-key checker.
// Other checkers available: isBST(root, "0") and isBST(root, "100", "0").
// NOTE(review): the latter's bounds look reversed (min "100" > max "0").
int BinarySearchTree::isBST()
{
    const int ordered = isBST(root);
    return ordered;
}

// Removes every node whose key lies outside [mindata, maxdata] (strcmp order)
// from the subtree at `node`, freeing removed nodes. Returns the new subtree
// root, which may be nullptr. Prints trace lines to stdout for each visited
// node and each freed node.
BinarySearchNode *BinarySearchTree::pruneBST(BinarySearchNode *node, const char *mindata, const char *maxdata)
{
    if (node == nullptr)
        return nullptr;
    // Post-order: children are pruned before deciding about this node.
    node->left = pruneBST(node->left, mindata, maxdata);
    node->right = pruneBST(node->right, mindata, maxdata);

    const int vsMin = strcmp(mindata, node->data);
    const int vsMax = strcmp(maxdata, node->data);
    printf("strcmp(mindata->%s, node->data->%s) = %d\n", mindata, node->data, vsMin);
    printf("strcmp(maxdata->%s, node->data->%s) = %d\n", maxdata, node->data, vsMax);
    if (vsMin <= 0 && vsMax >= 0)
        return node;

    // node < mindata: its left keys are smaller still, so the pruned left
    // subtree is already empty and the right subtree takes its place.
    // node > maxdata is the mirror case.
    BinarySearchNode *survivor = (vsMin > 0) ? node->right : node->left;
    printf("free(node(%s))\n", node->data);
    free(node); // previously only logged, leaking every pruned node
    return survivor;
}
// Prunes the whole tree to [mindata, maxdata] and returns the new root.
// Because pruned nodes are freed, `root` is updated to the result. It is
// nullptr when no key is in range.
BinarySearchNode *BinarySearchTree::pruneBST(const char *mindata, const char *maxdata)
{
    root = pruneBST(root, mindata, maxdata);
    return root;
}